package gorgonia

import (
	"fmt"
	"math/rand"
	"testing"

	"github.com/stretchr/testify/require"
	"gorgonia.org/tensor"
)

// groupNormTestCases is the table of inputs and recorded expectations that
// TestGroupNorm iterates over; BenchmarkGroupNorm reuses its first entry.
// All expected values are compared with an absolute tolerance of 1e-3.
var groupNormTestCases = []struct {
	// Dtype is the element type used for every tensor node created for the case.
	Dtype  tensor.Dtype
	// X supplies the input values. The consumers accept a []float32, a
	// []float64 or an InitWFn here.
	X      interface{}
	// XShape is the shape of the input tensor.
	XShape tensor.Shape

	// Groups and Channels are passed unchanged to GroupNorm.
	Groups, Channels int

	// ScaleInit initialises the scale tensor of shape ScaleShape. TestGroupNorm
	// also uses it to initialise the weight that the GroupNorm output is
	// multiplied with.
	ScaleInit  InitWFn
	ScaleShape tensor.Shape

	// BiasInit initialises the bias tensor of shape BiasShape.
	BiasInit  InitWFn
	BiasShape tensor.Shape

	// ExpectedTrainResult is the flattened result of multiplying the GroupNorm
	// output with the transposed weight; ExpectedCost is the mean of that result.
	ExpectedTrainResult, ExpectedCost                      interface{}
	// Expected gradients of the cost with respect to bias, scale and input,
	// each given as a flat slice.
	ExpectedBiasGrad, ExpectedScaleGrad, ExpectedInputGrad interface{}
}{
	{
		Dtype:               tensor.Float64,
		X:                   RangedFromWithStep(0.1, 2),
		XShape:              tensor.Shape{2, 20, 2, 2},
		Groups:              2,
		Channels:            20,
		ScaleInit:           RangedFromWithStep(0.3, 0.3),
		ScaleShape:          tensor.Shape{1, 20, 1, 1},
		BiasInit:            RangedFromWithStep(0.2, 0.2),
		BiasShape:           tensor.Shape{1, 20, 1, 1},
		ExpectedTrainResult: []float64{4385.900222723546, 10064.555955139396, 4385.900222723546, 10064.555955139393},
		ExpectedCost:        7225.228088931472,
		ExpectedInputGrad:   []float64{0.1455456629614057, 0.10756983263957326, 0.06959400231774085, 0.03161817199590841, 0.08135482742065983, 0.04532816344875147, 0.009301499476843111, -0.02672516449506525, 0.03275732267930655, -0.0013201749426776932, -0.03539767256466193, -0.06947517018664628, -0.00024685126265393265, -0.03237518253471408, -0.06450351380677433, -0.09663184507883454, -0.017657694405221844, -0.047836859327357845, -0.07801602424949394, -0.10819518917163007, -0.019475206748396964, -0.04770520532060915, -0.07593520389282099, -0.10416520246503308, -0.005699388292179461, -0.03198022051446739, -0.058261052736755414, -0.08454188495904336, 0.023669760963430445, -0.0006619049089335582, -0.02499357078129756, -0.04932523665366134, 0.068632241018433, 0.04624974149599304, 0.023867241973553298, 0.0014847424511133245, 0.12918805187282817, 0.10875471870031223, 0.08832138552779653, 0.06788805235528059, 0.1712175958138913, 0.11393295134647796, 0.05664830687906508, -0.0006363375883475797, 0.10775815768778663, 0.05242267957029734, -0.0029127985471915085, -0.05824827666467991, 0.059892050361074256, 0.006505738593509225, -0.04688057317405536, -0.1002668849416204, 0.027619273833755065, -0.023817871583885708, -0.07525501700152648, -0.12669216241916725, 0.010939828105828608, -0.038548150961888794, -0.08803613002960531, -0.13752410909732182, 0.009853713177293999, -0.0376850995404987, -0.08522391225829185, -0.13276272497608366, 0.02436092904815257, -0.021228717319716317, -0.06681836368758476, -0.11240801005545364, 0.054461475718402985, 0.010820995700458802, -0.032819484317485825, -0.07645996433542956, 0.10015535318804658, 0.05846403952002577, 0.016772725852005843, -0.02491858781601497, 0.16144256145708247, 0.12170041413898636, 0.08195826682088936, 0.042216119502794136, 0.14554566296140425, 0.10756983263957176, 0.06959400231773927, 0.03161817199590722, 0.0813548274206588, 0.045328163448750125, 0.00930149947684189, -0.026725164495066345, 0.032757322679305645, -0.0013201749426783316, 
-0.03539767256466275, -0.06947517018664673, -0.0002468512626543351, -0.0323751825347145, -0.06450351380677466, -0.09663184507883482, -0.017657694405222024, -0.04783685932735793, -0.07801602424949383, -0.10819518917163018, -0.019475206748396978, -0.04770520532060907, -0.07593520389282071, -0.1041652024650328, -0.005699388292179197, -0.031980220514466584, -0.05826105273675486, -0.08454188495904269, 0.02366976096343132, -0.0006619049089326978, -0.024993570781296715, -0.04932523665366029, 0.06863224101843413, 0.04624974149599437, 0.02386724197355461, 0.001484742451114407, 0.12918805187282967, 0.10875471870031372, 0.08832138552779778, 0.06788805235528228, 0.17121759581388218, 0.1139329513464693, 0.05664830687905731, -0.0006363375883546851, 0.10775815768777974, 0.052422679570290676, -0.0029127985471975038, -0.058248276664684795, 0.059892050361068705, 0.00650573859350434, -0.046880573174060025, -0.10026688494162439, 0.027619273833751734, -0.023817871583888817, -0.07525501700152937, -0.12669216241916992, 0.010939828105827054, -0.03854815096188968, -0.08803613002960642, -0.13752410909732227, 0.009853713177294665, -0.03768509954049826, -0.08522391225829029, -0.13276272497608232, 0.024360929048154567, -0.021228717319713653, -0.06681836368758098, -0.11240801005545009, 0.05446147571840765, 0.010820995700464131, -0.032819484317480274, -0.07645996433542379, 0.10015535318805302, 0.05846403952003243, 0.016772725852013615, -0.024918587816006976, 0.16144256145709068, 0.12170041413899568, 0.0819582668208989, 0.042216119502803906},
		ExpectedBiasGrad:    []float64{51.00000000000007, 55.80000000000008, 60.60000000000008, 65.40000000000009, 70.20000000000009, 75.0000000000001, 79.8000000000001, 84.60000000000008, 89.40000000000006, 94.20000000000005, 99.00000000000003, 103.80000000000001, 108.6, 113.39999999999998, 118.19999999999996, 122.99999999999994, 127.79999999999993, 132.5999999999999, 137.3999999999999, 142.19999999999987},
		ExpectedScaleGrad:   []float64{-79.3960426535744, -67.54511124603596, -52.36760260129374, -33.863516719347764, -12.03285360019801, 13.124386756155513, 41.60820434971281, 73.41859918047385, 108.55557124843864, 147.0191205536072, -154.24403049065916, -125.76021289710187, -93.94981806634075, -58.81284599837597, -20.349296693207414, 21.440829849164896, 66.55753362874096, 115.00081464552079, 166.77067289950435, 221.86710839069167},
	},
	{
		Dtype:               Float64,
		X:                   RangedFromWithStep(0.5, 5),
		XShape:              tensor.Shape{3, 2},
		Groups:              2,
		Channels:            2,
		ScaleInit:           RangedFromWithStep(0.3, 0.3),
		ScaleShape:          tensor.Shape{1, 2},
		BiasInit:            RangedFromWithStep(0.2, 0.2),
		BiasShape:           tensor.Shape{1, 2},
		ExpectedTrainResult: []float64{0.3, 0.66, 1.02, 0.3000000000000614, 0.6600000000001024, 1.0200000000001432, 0.3000000000000614, 0.6600000000001024, 1.0200000000001432},
		ExpectedCost:        0.6600000000000682,
		ExpectedInputGrad:   []float64{0, 0, 0, 0, 0, 0},
		ExpectedBiasGrad:    []float64{0.8999999999999998, 1.2},
		ExpectedScaleGrad:   []float64{0, 0},
	},
	{
		Dtype:               tensor.Float64,
		X:                   RangedFromWithStep(0.1, 2),
		XShape:              tensor.Shape{2, 20, 2, 3},
		Groups:              2,
		Channels:            20,
		ScaleInit:           RangedFromWithStep(0.3, 0.3),
		ScaleShape:          tensor.Shape{1, 20, 1, 1},
		BiasInit:            RangedFromWithStep(0.2, 0.2),
		BiasShape:           tensor.Shape{1, 20, 1, 1},
		ExpectedTrainResult: []float64{9841.673486015334, 22618.005443766822, 9841.673486015337, 22618.005443766833},
		ExpectedCost:        16229.83946489108,
		ExpectedInputGrad:   []float64{0.1515293450089557, 0.12627689001893583, 0.10102443502891593, 0.07577198003889604, 0.050519525048876135, 0.02526707005885622, 0.08706225857660721, 0.06310902214640476, 0.03915578571620232, 0.0152025492859999, -0.008750687144202546, -0.03270392357440498, 0.03818579486206826, 0.01553177699168333, -0.0071222408787015995, -0.029776258749086654, -0.05243027661947153, -0.07508429448985651, 0.004899953865339035, -0.016454845445228447, -0.03780964475579597, -0.05916444406636345, -0.08051924337693092, -0.10187404268749845, -0.012795264413580504, -0.03285084516433051, -0.05290642591508063, -0.07296200666583053, -0.09301758741658053, -0.11307316816733068, -0.014899859974690552, -0.033656222165622945, -0.05241258435655535, -0.07116894654748807, -0.08992530873842056, -0.10868167092935305, -0.0014138328179905402, -0.018870976449105625, -0.036328120080220724, -0.053785263711335796, -0.0712424073424508, -0.08869955097356588, 0.027662817056518692, 0.011504891985221227, -0.00465303308607623, -0.020810958157374132, -0.036968883228671603, -0.05312680829996905, 0.0723300896488379, 0.057471383137357515, 0.042612676625877345, 0.027753970114397397, 0.012895263602917234, -0.001963442908562707, 0.1325879849589662, 0.11902849700730399, 0.10546900905564112, 0.09190952110397847, 0.07835003315231559, 0.06479054520065294, 0.18028253628573188, 0.14216424392528837, 0.10404595156484509, 0.0659276592044018, 0.027809366843958294, -0.01030892551648499, 0.11657353921989033, 0.07975446541926479, 0.04293539161863924, 0.006116317818012806, -0.03070275598261296, -0.06752182978323873, 0.06845516487185921, 0.032935309631050735, -0.00258454560975796, -0.03810440085056643, -0.0736242560913749, -0.10914411133218338, 0.035927413241637396, 0.0017067765606464391, -0.032513860120344296, -0.06673449680133547, -0.10095513348232665, -0.1351757701633174, 0.018990284329225338, -0.013931133791948325, -0.04685255191312221, -0.0797739700342952, -0.11269538815546865, -0.14561680627664209, 
0.017643778134622368, -0.013978421426733778, -0.04560062098808948, -0.07722282054944563, -0.10884502011080133, -0.1404672196721577, 0.03188789465782982, 0.0015649136562909671, -0.02875806734524744, -0.05908104834678629, -0.08940402934832425, -0.11972701034986333, 0.061722633898846135, 0.03269887145712547, 0.003675109015404354, -0.025348653426317203, -0.05437241586803854, -0.0833961783097592, 0.10714799585767265, 0.07942345197576883, 0.051698908093865015, 0.023974364211960975, -0.0037501796699428436, -0.03147472355184666, 0.1681639805343067, 0.14173865521222062, 0.11531332989013454, 0.08888800456804824, 0.06246267924596127, 0.03603735392387608, 0.15152934500895698, 0.12627689001893705, 0.10102443502891711, 0.07577198003889718, 0.050519525048877245, 0.02526707005885731, 0.08706225857660854, 0.0631090221464059, 0.03915578571620326, 0.015202549286000622, -0.008750687144201574, -0.032703923574404214, 0.03818579486206897, 0.015531776991684065, -0.007122240878700836, -0.029776258749086182, -0.05243027661947108, -0.07508429448985598, 0.00489995386533959, -0.016454845445228017, -0.037809644755795624, -0.05916444406636279, -0.0805192433769304, -0.101874042687498, -0.012795264413580032, -0.032850845164330345, -0.05290642591508021, -0.07296200666583053, -0.0930175874165804, -0.11307316816733026, -0.014899859974690344, -0.03365622216562292, -0.05241258435655549, -0.07116894654748807, -0.08992530873842064, -0.10868167092935321, -0.001413832817990457, -0.018870976449105736, -0.036328120080221016, -0.05378526371133585, -0.07124240734245113, -0.08869955097356597, 0.027662817056518296, 0.011504891985220755, -0.0046530330860767855, -0.020810958157374326, -0.03696888322867231, -0.05312680829996941, 0.07233008964883769, 0.057471383137357, 0.042612676625876755, 0.02775397011439651, 0.012895263602916707, -0.0019634429085635396, 0.1325879849589655, 0.119028497007303, 0.1054690090556405, 0.09190952110397754, 0.07835003315231504, 0.06479054520065208, 0.18028253628572166, 
0.14216424392527838, 0.10404595156483509, 0.0659276592043927, 0.027809366843949412, -0.010308925516492984, 0.1165735392198819, 0.07975446541925724, 0.04293539161863169, 0.006116317818006145, -0.0307027559826194, -0.06752182978324495, 0.06845516487185321, 0.03293530963104452, -0.002584545609763289, -0.038104400850571984, -0.07362425609137979, -0.1091441113321876, 0.035927413241632955, 0.0017067765606428864, -0.03251386012034807, -0.06673449680133814, -0.1009551334823291, -0.13517577016331916, 0.018990284329222895, -0.013931133791950323, -0.04685255191312354, -0.07977397003429587, -0.11269538815546998, -0.1456168062766432, 0.017643778134622146, -0.013978421426733334, -0.045600620988088814, -0.0772228205494443, -0.10884502011079977, -0.14046721967215614, 0.031887894657831595, 0.0015649136562938537, -0.028758067345244775, -0.059081048346783405, -0.08940402934832115, -0.11972701034985977, 0.061722633898850354, 0.03269887145713035, 0.0036751090154094612, -0.025348653426312318, -0.05437241586803321, -0.08339617830975321, 0.10714799585767842, 0.07942345197577438, 0.05169890809387123, 0.02397436421196808, -0.003750179669935072, -0.031474723551838224, 0.16816398053431492, 0.1417386552122295, 0.1153133298901432, 0.08888800456805779, 0.062462679245971486, 0.036037353923885185},
		ExpectedBiasGrad:    []float64{114.30000000000001, 125.09999999999997, 135.89999999999992, 146.69999999999987, 157.49999999999983, 168.29999999999976, 179.09999999999974, 189.8999999999997, 200.69999999999965, 211.4999999999996, 222.29999999999956, 233.0999999999995, 243.89999999999952, 254.69999999999942, 265.49999999999943, 276.2999999999994, 287.09999999999934, 297.8999999999993, 308.6999999999992, 319.4999999999991},
		ExpectedScaleGrad:   []float64{-177.89766666727434, -151.3936080469977, -117.4060505221725, -75.93499409279866, -26.980438758876208, 29.457615479594864, 93.37916862261457, 164.78422067018292, 243.67277162229993, 330.04482147896556, -346.27639201961847, -282.3548388765987, -210.9497868290304, -132.0612358769134, -45.68918602024776, 48.166362740966505, 149.5054104067294, 258.3279569770409, 374.63400245190104, 498.42354683130975},
	},
	{
		Dtype:               Float32,
		X:                   RangedFromWithStep(0.5, 5),
		XShape:              tensor.Shape{3, 2},
		Groups:              2,
		Channels:            2,
		ScaleInit:           RangedFromWithStep(0.3, 0.3),
		ScaleShape:          tensor.Shape{1, 2},
		BiasInit:            RangedFromWithStep(0.2, 0.2),
		BiasShape:           tensor.Shape{1, 2},
		ExpectedTrainResult: []float32{0.3, 0.66, 1.02, 0.3, 0.66, 1.02, 0.3, 0.66, 1.02},
		ExpectedCost:        0.66,
		ExpectedInputGrad:   []float32{0, 0, 0, 0, 0, 0},
		ExpectedBiasGrad:    []float32{0.8999999999999998, 1.2},
		ExpectedScaleGrad:   []float32{0, 0},
	},
}

// TestGroupNorm runs every entry of groupNormTestCases as a sub-test.
//
// For each case it builds a graph that applies GroupNorm to the input,
// flattens the result to two dimensions when needed, multiplies it with the
// transpose of a fixed weight and takes the mean as the cost. After running
// the graph on a tape machine it compares the multiplied output, the cost and
// the gradients of the cost with respect to bias, scale and input against the
// recorded expectations, using an absolute tolerance of 1e-3.
func TestGroupNorm(t *testing.T) {
	for i, tC := range groupNormTestCases {
		desc := fmt.Sprintf("Example #%d %v - %v", i+1, tC.Dtype, tC.XShape)
		t.Run(desc, func(t *testing.T) {
			rand.Seed(0)

			c := require.New(t)

			g := NewGraph()

			// The input may be given as literal data or as an initialiser function.
			var initOpt NodeConsOpt

			switch v := tC.X.(type) {
			case []float32:
				initOpt = WithValue(
					tensor.New(
						tensor.Of(tensor.Float32),
						tensor.WithShape(tC.XShape...),
						tensor.WithBacking(v),
					),
				)
			case []float64:
				// The declared dtype has to agree with the []float64 backing.
				initOpt = WithValue(
					tensor.New(
						tensor.Of(tensor.Float64),
						tensor.WithShape(tC.XShape...),
						tensor.WithBacking(v),
					),
				)
			case InitWFn:
				initOpt = WithInit(v)
			default:
				// Fail clearly instead of handing a nil option to NewTensor.
				t.Fatalf("unsupported type %T for test case input X", v)
			}

			x := NewTensor(g, tC.Dtype, tC.XShape.Dims(), WithShape(tC.XShape...), initOpt, WithName("x"))

			scale := NewTensor(g, tC.Dtype, tC.ScaleShape.Dims(), WithShape(tC.ScaleShape...), WithInit(tC.ScaleInit), WithName("scale"))
			bias := NewTensor(g, tC.Dtype, tC.BiasShape.Dims(), WithShape(tC.BiasShape...), WithInit(tC.BiasInit), WithName("bias"))

			// The weight has one row per leading-dimension entry of the input and
			// one column per remaining element, so the flattened output can be
			// multiplied with its transpose.
			fcWeight := NewTensor(g, tC.Dtype, 2, WithShape(tC.XShape[0], tensor.Shape(tC.XShape[1:]).TotalSize()), WithInit(tC.ScaleInit), WithName("fcWeight"))

			y, err := GroupNorm(x, scale, bias, tC.Groups, tC.Channels, 1e-5)
			c.NoError(err)

			if y.Dims() > 2 {
				y = Must(Reshape(y, fcWeight.Shape()))
			}

			wT := Must(Transpose(fcWeight, 1, 0))

			y = Must(Mul(y, wT))

			cost := Must(Mean(y))

			_, err = Grad(cost, x, fcWeight, scale, bias)
			c.NoError(err)

			m := NewTapeMachine(g, BindDualValues(x, fcWeight, scale, bias), TraceExec(), WithInfWatch(), WithNaNWatch())

			err = m.RunAll()
			c.NoError(err)

			// ioutil.WriteFile("gn.dot", []byte(g.ToDot()), 0644)

			c.NoError(m.Close())

			t.Logf("%v output:\n%v", desc, y.Value())
			t.Logf("%v cost:\n%v", desc, cost.Value())
			t.Logf("%v input grad:\n%v", desc, x.Deriv().Value())
			// t.Logf("%v output grad:\n%v", desc, y.Deriv().Value())
			// t.Logf("%v bias grad: %v", desc, bias.Deriv().Value())
			// t.Logf("%v scale grad: %v", desc, scale.Deriv().Value())

			c.InDeltaSlice(tC.ExpectedTrainResult, y.Value().Data(), 1e-3, "Wrong Output\ngot=%#v\nexpected=%#v", y.Value().Data(), tC.ExpectedTrainResult)
			c.InDelta(tC.ExpectedCost, cost.Value().Data(), 1e-3)

			c.InDeltaSlice(tC.ExpectedBiasGrad, bias.Deriv().Value().Data(), 1e-3, "Bias Grad doesn't match:\ngot=%#v expected=%#v", bias.Deriv().Value().Data(), tC.ExpectedBiasGrad)
			c.InDeltaSlice(tC.ExpectedScaleGrad, scale.Deriv().Value().Data(), 1e-3, "Scale Grad doesn't match:\ngot=%#v expected=%#v", scale.Deriv().Value().Data(), tC.ExpectedScaleGrad)
			c.InDeltaSlice(tC.ExpectedInputGrad, x.Deriv().Value().Data(), 1e-3, "Input Grad doesn't match:\ngot=%#v expected=%#v", x.Deriv().Value().Data(), tC.ExpectedInputGrad)
		})
	}
}

// BenchmarkGroupNorm times one full execution (forward pass plus gradients)
// of a GroupNorm graph built from the first entry of groupNormTestCases.
//
// Graph construction and error checking stay outside the timed region: the
// timer is stopped up front and only started around each RunAll call.
func BenchmarkGroupNorm(b *testing.B) {
	b.StopTimer()

	tC := groupNormTestCases[0]

	rand.Seed(0)

	c := require.New(b)

	g := NewGraph()

	// The input may be given as literal data or as an initialiser function.
	var initOpt NodeConsOpt

	switch v := tC.X.(type) {
	case []float32:
		initOpt = WithValue(
			tensor.New(
				tensor.Of(tensor.Float32),
				tensor.WithShape(tC.XShape...),
				tensor.WithBacking(v),
			),
		)
	case []float64:
		// The declared dtype has to agree with the []float64 backing.
		initOpt = WithValue(
			tensor.New(
				tensor.Of(tensor.Float64),
				tensor.WithShape(tC.XShape...),
				tensor.WithBacking(v),
			),
		)
	case InitWFn:
		initOpt = WithInit(v)
	default:
		// Fail clearly instead of handing a nil option to NewTensor.
		b.Fatalf("unsupported type %T for test case input X", v)
	}

	x := NewTensor(g, tC.Dtype, tC.XShape.Dims(), WithShape(tC.XShape...), initOpt, WithName("x"))

	scale := NewTensor(g, tC.Dtype, tC.ScaleShape.Dims(), WithShape(tC.ScaleShape...), WithInit(tC.ScaleInit), WithName("scale"))
	bias := NewTensor(g, tC.Dtype, tC.BiasShape.Dims(), WithShape(tC.BiasShape...), WithInit(tC.BiasInit), WithName("bias"))

	y, err := GroupNorm(x, scale, bias, tC.Groups, tC.Channels, 1e-5)
	c.NoError(err)

	cost := Must(Mean(y))

	_, err = Grad(cost, x, scale, bias)
	c.NoError(err)

	m := NewTapeMachine(g, BindDualValues(x, scale, bias), TraceExec(), WithInfWatch(), WithNaNWatch())
	// Release the machine's resources once the benchmark is done.
	defer func() {
		if cerr := m.Close(); cerr != nil {
			b.Errorf("closing tape machine: %v", cerr)
		}
	}()

	for i := 0; i < b.N; i++ {
		b.StartTimer()
		err = m.RunAll()
		b.StopTimer()

		c.NoError(err)

		// NOTE(review): Reset is expected to rewind the machine so the next
		// RunAll executes the whole program again instead of returning
		// immediately - confirm against the tape machine documentation.
		m.Reset()
	}
}
